"""TensorFlow interface for third-party optimizers.

Code below is taken from https://github.com/tensorflow/tensorflow/blob/v1.15.2/tensorflow/contrib/opt/python/training/external_optimizer.py,
because the ``tf.contrib`` module is not included in TensorFlow 2.

Another solution is to use TensorFlow Probability; see the references below.
However, that approach requires setting the weights before building the network and loss,
which is inconsistent with how the other optimizers work in graph mode.
A possible workaround would be to add a TFPOptimizerInterface similar to ScipyOptimizerInterface.

- https://www.tensorflow.org/probability/api_docs/python/tfp/optimizer/lbfgs_minimize
- https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/lbfgs_test.py
- https://stackoverflow.com/questions/58591562/how-can-we-use-lbfgs-minimize-in-tensorflow-2-0
- https://stackoverflow.com/questions/59029854/use-scipy-optimizer-with-tensorflow-2-0-for-neural-network-training
- https://pychao.com/2019/11/02/optimize-tensorflow-keras-models-with-l-bfgs-from-tensorflow-probability/
- https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993
- https://github.com/pierremtb/PINNs-TF2.0/blob/master/utils/neuralnetwork.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from .backend import tf

__all__ = ["ExternalOptimizerInterface", "ScipyOptimizerInterface"]


class ExternalOptimizerInterface(object):
    """Base class for interfaces with external optimization algorithms.

    Subclass this and implement `_minimize` in order to wrap a new optimization
    algorithm.

    `ExternalOptimizerInterface` should not be instantiated directly; instead use
    e.g. `ScipyOptimizerInterface`.
    """

    def __init__(
        self,
        loss,
        var_list=None,
        equalities=None,
        inequalities=None,
        var_to_bounds=None,
        **optimizer_kwargs
    ):
        """Initialize a new interface instance.

        Args:
          loss: A scalar `Tensor` to be minimized.
          var_list: Optional `list` of `Variable` objects to update to minimize
            `loss`.  Defaults to the list of variables collected in the graph
            under the key `GraphKeys.TRAINABLE_VARIABLES`.
          equalities: Optional `list` of equality constraint scalar `Tensor`s to be
            held equal to zero.
          inequalities: Optional `list` of inequality constraint scalar `Tensor`s
            to be held nonnegative.
          var_to_bounds: Optional `dict` where each key is an optimization
            `Variable` and each corresponding value is a length-2 tuple of
            `(low, high)` bounds. Although enforcing this kind of simple constraint
            could be accomplished with the `inequalities` arg, not all optimization
            algorithms support general inequality constraints, e.g. L-BFGS-B. Both
            `low` and `high` can either be numbers or anything convertible to a
            NumPy array that can be broadcast to the shape of `var` (using
            `np.broadcast_to`). To indicate that there is no bound, use `None` (or
            `+/- np.inf`). For example, if `var` is a 3x2 matrix, then any of
            the following corresponding `bounds` could be supplied:
            * `(0, np.inf)`: Each element of `var` held nonnegative.
            * `(-np.inf, [1, 2])`: First column at most 1, second column at
              most 2.
            * `(-np.inf, [[1], [2], [3]])`: First row at most 1, second row at
              most 2, etc.
            * `(-np.inf, [[1, 2], [3, 4], [5, 6]])`: Entry `var[0, 0]` at most 1,
              `var[0, 1]` at most 2, etc.
          **optimizer_kwargs: Other subclass-specific keyword arguments.
        """
        self._loss = loss
        self._equalities = equalities or []
        self._inequalities = inequalities or []

        if var_list is None:
            self._vars = tf.trainable_variables()
        else:
            self._vars = list(var_list)

        self._packed_bounds = self._pack_bounds(self._vars, var_to_bounds)

        # Placeholders and assign ops used to write the optimized values back
        # into the variables once the external optimizer has finished.
        self._update_placeholders = [tf.placeholder(var.dtype) for var in self._vars]
        self._var_updates = [
            var.assign(tf.reshape(placeholder, _get_shape_tuple(var)))
            for var, placeholder in zip(self._vars, self._update_placeholders)
        ]

        loss_grads = _compute_gradients(loss, self._vars)
        equalities_grads = [
            _compute_gradients(equality, self._vars) for equality in self._equalities
        ]
        inequalities_grads = [
            _compute_gradients(inequality, self._vars)
            for inequality in self._inequalities
        ]

        self.optimizer_kwargs = optimizer_kwargs

        # Everything the external optimizer sees is a single flat vector.
        self._packed_var = self._pack(self._vars)
        self._packed_loss_grad = self._pack(loss_grads)
        self._packed_equality_grads = [
            self._pack(equality_grads) for equality_grads in equalities_grads
        ]
        self._packed_inequality_grads = [
            self._pack(inequality_grads) for inequality_grads in inequalities_grads
        ]

        # Slices into the packed vector, one per variable, used to unpack it.
        dims = [_prod(_get_shape_tuple(var)) for var in self._vars]
        accumulated_dims = list(_accumulate(dims))
        self._packing_slices = [
            slice(start, end)
            for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:])
        ]

    @staticmethod
    def _pack_bounds(var_list, var_to_bounds):
        """Flatten per-variable bounds into one `(low, high)` pair per element.

        Args:
          var_list: The optimization variables, in packing order. Each must
            provide `get_shape().as_list()`.
          var_to_bounds: `dict` mapping variables to `(low, high)` bounds, or
            `None`. Variables absent from the dict are left unbounded.

        Returns:
          `None` if `var_to_bounds` is `None`; otherwise a list with one
          `(low, high)` tuple per scalar element, variable by variable, each
          variable flattened in NumPy C order (intended to match `_pack`).
        """
        if var_to_bounds is None:
            return None
        left_packed_bounds = []
        right_packed_bounds = []
        for var in var_list:
            shape = var.get_shape().as_list()
            # `np.infty` was removed in NumPy 2.0; `np.inf` works everywhere.
            bounds = (-np.inf, np.inf)
            if var in var_to_bounds:
                bounds = var_to_bounds[var]
            left_packed_bounds.extend(list(np.broadcast_to(bounds[0], shape).flat))
            right_packed_bounds.extend(list(np.broadcast_to(bounds[1], shape).flat))
        return list(zip(left_packed_bounds, right_packed_bounds))

    def minimize(
        self,
        session=None,
        feed_dict=None,
        fetches=None,
        step_callback=None,
        loss_callback=None,
        **run_kwargs
    ):
        """Minimize a scalar `Tensor`.

        Variables subject to optimization are updated in-place at the end of
        optimization.

        Note that this method does *not* just return a minimization `Op`, unlike
        `Optimizer.minimize()`; instead it actually performs minimization by
        executing commands to control a `Session`.

        Args:
          session: A `Session` instance. Defaults to the default session.
          feed_dict: A feed dict to be passed to calls to `session.run`.
          fetches: A list of `Tensor`s to fetch and supply to `loss_callback`
            as positional arguments.
          step_callback: A function to be called at each optimization step;
            arguments are the current values of all optimization variables
            flattened into a single vector.
          loss_callback: A function to be called every time the loss and gradients
            are computed, with evaluated fetches supplied as positional arguments.
          **run_kwargs: kwargs to pass to the final `session.run` that writes
            the optimized values back into the variables.
        """
        session = session or tf.get_default_session()
        feed_dict = feed_dict or {}
        fetches = fetches or []

        loss_callback = loss_callback or (lambda *fetches: None)
        step_callback = step_callback or (lambda xk: None)

        # Construct loss function and associated gradient.
        loss_grad_func = self._make_eval_func(
            [self._loss, self._packed_loss_grad],
            session,
            feed_dict,
            fetches,
            loss_callback,
        )

        # Construct equality constraint functions and associated gradients.
        equality_funcs = self._make_eval_funcs(
            self._equalities, session, feed_dict, fetches
        )
        equality_grad_funcs = self._make_eval_funcs(
            self._packed_equality_grads, session, feed_dict, fetches
        )

        # Construct inequality constraint functions and associated gradients.
        inequality_funcs = self._make_eval_funcs(
            self._inequalities, session, feed_dict, fetches
        )
        inequality_grad_funcs = self._make_eval_funcs(
            self._packed_inequality_grads, session, feed_dict, fetches
        )

        # Get initial value from TF session.
        initial_packed_var_val = session.run(self._packed_var)

        # Perform minimization.
        packed_var_val = self._minimize(
            initial_val=initial_packed_var_val,
            loss_grad_func=loss_grad_func,
            equality_funcs=equality_funcs,
            equality_grad_funcs=equality_grad_funcs,
            inequality_funcs=inequality_funcs,
            inequality_grad_funcs=inequality_grad_funcs,
            packed_bounds=self._packed_bounds,
            step_callback=step_callback,
            optimizer_kwargs=self.optimizer_kwargs,
        )
        var_vals = [
            packed_var_val[packing_slice] for packing_slice in self._packing_slices
        ]

        # Set optimization variables to their new values.
        session.run(
            self._var_updates,
            feed_dict=dict(zip(self._update_placeholders, var_vals)),
            **run_kwargs
        )

    def _minimize(
        self,
        initial_val,
        loss_grad_func,
        equality_funcs,
        equality_grad_funcs,
        inequality_funcs,
        inequality_grad_funcs,
        packed_bounds,
        step_callback,
        optimizer_kwargs,
    ):
        """Wrapper for a particular optimization algorithm implementation.

        It would be appropriate for a subclass implementation of this method to
        raise `NotImplementedError` if unsupported arguments are passed: e.g. if an
        algorithm does not support constraints but `len(equality_funcs) > 0`.

        Args:
          initial_val: A NumPy vector of initial values.
          loss_grad_func: A function accepting a NumPy packed variable vector and
            returning two outputs, a loss value and the gradient of that loss with
            respect to the packed variable vector.
          equality_funcs: A list of functions each of which specifies a scalar
            quantity that an optimizer should hold exactly zero.
          equality_grad_funcs: A list of gradients of equality_funcs.
          inequality_funcs: A list of functions each of which specifies a scalar
            quantity that an optimizer should hold >= 0.
          inequality_grad_funcs: A list of gradients of inequality_funcs.
          packed_bounds: A list of bounds for each index, or `None`.
          step_callback: A callback function to execute at each optimization step,
            supplied with the current value of the packed variable vector.
          optimizer_kwargs: Other key-value arguments available to the optimizer.

        Returns:
          The optimal variable vector as a NumPy vector.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "To use ExternalOptimizerInterface, subclass from it and implement "
            "the _minimize() method."
        )

    @classmethod
    def _pack(cls, tensors):
        """Pack a list of `Tensor`s into a single, flattened, rank-1 `Tensor`.

        Returns `None` for an empty list.
        """
        if not tensors:
            return None
        elif len(tensors) == 1:
            return tf.reshape(tensors[0], [-1])
        else:
            flattened = [tf.reshape(tensor, [-1]) for tensor in tensors]
            return tf.concat(flattened, 0)

    def _make_eval_func(self, tensors, session, feed_dict, fetches, callback=None):
        """Construct a function that evaluates a `Tensor` or list of `Tensor`s.

        The returned function takes a packed NumPy vector `x`, feeds its slices
        in place of the optimization variables, runs `tensors` plus `fetches`,
        passes the evaluated `fetches` to `callback` (if callable), and returns
        the list of evaluated `tensors`.
        """
        if not isinstance(tensors, list):
            tensors = [tensors]
        num_tensors = len(tensors)

        def eval_func(x):
            """Function to evaluate a `Tensor`."""
            # Feed candidate values directly in place of the variables.
            augmented_feed_dict = {
                var: x[packing_slice].reshape(_get_shape_tuple(var))
                for var, packing_slice in zip(self._vars, self._packing_slices)
            }
            # Entries from the caller's feed_dict take precedence.
            augmented_feed_dict.update(feed_dict)
            augmented_fetches = tensors + fetches

            augmented_fetch_vals = session.run(
                augmented_fetches, feed_dict=augmented_feed_dict
            )

            if callable(callback):
                callback(*augmented_fetch_vals[num_tensors:])

            return augmented_fetch_vals[:num_tensors]

        return eval_func

    def _make_eval_funcs(self, tensors, session, feed_dict, fetches, callback=None):
        """Return one `_make_eval_func` evaluator per tensor in `tensors`."""
        return [
            self._make_eval_func(tensor, session, feed_dict, fetches, callback)
            for tensor in tensors
        ]


class ScipyOptimizerInterface(ExternalOptimizerInterface):
    """Wrapper allowing `scipy.optimize.minimize` to operate a `tf.compat.v1.Session`.

    The SciPy method defaults to L-BFGS-B and can be chosen with the `method`
    keyword argument. All other extra keyword arguments (e.g. `options`) are
    forwarded to `scipy.optimize.minimize` unchanged.

    Example:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    optimizer = ScipyOptimizerInterface(loss, options={'maxiter': 100})
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [0., 0.].
    ```

    Example with simple bound constraints:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    optimizer = ScipyOptimizerInterface(
        loss, var_to_bounds={vector: ([1, 2], np.inf)})
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [1., 2.].
    ```

    Example with more complicated constraints:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    # Ensure the vector's y component is = 1.
    equalities = [vector[1] - 1.]
    # Ensure the vector's x component is >= 1.
    inequalities = [vector[0] - 1.]
    # Our default SciPy optimization algorithm, L-BFGS-B, does not support
    # general constraints. Thus we use SLSQP instead.
    optimizer = ScipyOptimizerInterface(
        loss, equalities=equalities, inequalities=inequalities, method='SLSQP')
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [1., 1.].
    ```
    """

    _DEFAULT_METHOD = "L-BFGS-B"

    def _minimize(
        self,
        initial_val,
        loss_grad_func,
        equality_funcs,
        equality_grad_funcs,
        inequality_funcs,
        inequality_grad_funcs,
        packed_bounds,
        step_callback,
        optimizer_kwargs,
    ):
        """Run `scipy.optimize.minimize` and return the optimal packed vector.

        The `jac`, `callback`, `method`, `constraints` and `bounds` arguments
        are filled in here. The remaining `optimizer_kwargs`, minus `method`,
        are passed through to SciPy.

        Raises:
          ValueError: If `optimizer_kwargs` tries to supply one of the
            arguments this method sets itself.
        """

        def loss_grad_func_wrapper(x):
            # The L-BFGS-B Fortran backend in SciPy needs a float64 gradient.
            loss_value, grad_value = loss_grad_func(x)
            return loss_value, grad_value.astype("float64")

        passthrough_kwargs = dict(optimizer_kwargs.items())
        chosen_method = passthrough_kwargs.pop("method", self._DEFAULT_METHOD)

        # Equality constraints come first, then inequalities.
        constraint_specs = [
            {"type": "eq", "fun": fun, "jac": jac}
            for fun, jac in zip(equality_funcs, equality_grad_funcs)
        ] + [
            {"type": "ineq", "fun": fun, "jac": jac}
            for fun, jac in zip(inequality_funcs, inequality_grad_funcs)
        ]

        scipy_kwargs = {
            "jac": True,
            "callback": step_callback,
            "method": chosen_method,
            "constraints": constraint_specs,
            "bounds": packed_bounds,
        }

        # These arguments are owned by this interface; reject user overrides.
        for reserved_name in scipy_kwargs:
            if reserved_name not in passthrough_kwargs:
                continue
            if reserved_name == "bounds":
                # 'bounds' gets its own message: the var_to_bounds argument
                # was added after this module had already been released.
                raise ValueError(
                    "Bounds must be set using the var_to_bounds argument"
                )
            raise ValueError(
                "Optimizer keyword arg '{}' is set "
                "automatically and cannot be injected manually".format(reserved_name)
            )

        scipy_kwargs.update(passthrough_kwargs)

        import scipy.optimize  # pylint: disable=g-import-not-at-top

        outcome = scipy.optimize.minimize(
            loss_grad_func_wrapper, initial_val, **scipy_kwargs
        )

        report_lines = [
            "Optimization terminated with:",
            "  Message: %s",
            "  Objective function value: %f",
        ]
        report_args = [outcome.message, outcome.fun]
        # Not every SciPy method reports nit / nfev; log only what exists.
        optional_fields = (
            ("nit", "  Number of iterations: %d"),
            ("nfev", "  Number of functions evaluations: %d"),
        )
        for field_name, field_line in optional_fields:
            if hasattr(outcome, field_name):
                report_lines.append(field_line)
                report_args.append(getattr(outcome, field_name))
        tf.logging.info("\n".join(report_lines), *report_args)

        return outcome["x"]


def _accumulate(list_):
    total = 0
    yield total
    for x in list_:
        total += x
        yield total


def _get_shape_tuple(tensor):
    return tuple(tensor.shape)


def _prod(array):
    prod = 1
    for value in array:
        prod *= value
    return prod


def _compute_gradients(tensor, var_list):
    """Return the gradient of ``tensor`` w.r.t. each variable in ``var_list``.

    ``tf.gradients`` can return ``None`` for a variable where a zero gradient
    is wanted. Such entries are replaced with ``tf.zeros_like(var)``, so every
    element of the returned list is a tensor.
    """
    raw_grads = tf.gradients(tensor, var_list)
    result = []
    for var, grad in zip(var_list, raw_grads):
        if grad is None:
            grad = tf.zeros_like(var)
        result.append(grad)
    return result
